import os
from glob import glob
import numpy as np
import cv2
import torch
from align_ply_from_ape_log import load_gaussians_from_ply
import sys
ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(ROOT, '../'))

from vigs.gaussian.utils.graphics_utils import getProjectionMatrix2
import lietorch
from decimal import Decimal
import json
import os

import cv2
import numpy as np
# np.random.seed(42)   # fix seed
import torch
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from vigs.util.utils import Log
from vigs.gaussian.renderer import render
from vigs.gaussian.utils.loss_utils import ssim, psnr
from vigs.gaussian.utils.camera_utils import Camera

# Seed the NumPy and PyTorch random number generators with a fixed value so
# that repeated runs draw the same random numbers (the main script samples its
# evaluation frames with np.random.permutation). The stdlib `random` module is
# not seeded here.
seed = 0
# random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

def eval_rendering(
    gtimages,
    gtdepthdir,
    traj,
    gaussians,
    save_dir,
    subfolder_name,
    background,
    projection_matrix,
    K,
    # kf_idx,
    iteration="final",
):
    """Render every ground-truth frame, save the renders and score them.

    For each ``(idx, image)`` in ``gtimages`` a ``Camera`` is built from the
    image and ``traj[idx]``, the Gaussians are rendered from that camera, and
    PSNR / SSIM / LPIPS are computed against the ground-truth image. The
    rendered image, ground-truth image, absolute-difference image and rendered
    depth are written under ``{save_dir}/{subfolder_name}/``.

    Args:
        gtimages: Mapping from integer frame index to image tensor. Each image
            is squeezed and divided by 255 before use (presumably 0-255 pixel
            values in channel-first layout -- confirm against the caller).
        gtdepthdir: Directory of ground-truth depth images, or ``None``. When
            given, its sorted file listing is indexed by the frame index and
            rendered depth is zeroed wherever the ground-truth depth is <= 0.
        traj: Poses indexable by frame index; ``traj[idx]`` is forwarded to
            ``Camera.init_from_tracking``.
        gaussians: Scene representation forwarded to ``render``.
        save_dir: Root output directory.
        subfolder_name: Sub-directory of ``save_dir`` that receives the images.
        background: Background colour forwarded to ``render``.
        projection_matrix: Forwarded to ``Camera.init_from_tracking``.
        K: Intrinsics forwarded to ``Camera.init_from_tracking``.
        iteration: Tag appended to the output directory names.

    Returns:
        dict with keys ``mean_psnr``, ``mean_ssim``, ``mean_lpips`` and
        ``mean_l1``. Depth L1 is currently never accumulated, so ``mean_l1``
        is always 0.

    Notes:
        * Requires a CUDA device (the LPIPS network and the ground-truth
          images are moved to ``"cuda"``).
        * The summary JSON is written to
          ``{save_dir}/psnr/{iteration}/final_result.json``; the path does not
          include ``subfolder_name``, so calls that share ``save_dir`` and
          ``iteration`` overwrite each other's file.
    """
    gtdepths = sorted(os.listdir(gtdepthdir)) if gtdepthdir is not None else None
    psnr_array, ssim_array, lpips_array, l1_array = [], [], [], []
    cal_lpips = LearnedPerceptualImagePatchSimilarity(net_type="alex", normalize=True).to("cuda")

    image_save_dir = f'{save_dir}/{subfolder_name}/image_{iteration}'
    depth_save_dir = f'{save_dir}/{subfolder_name}/depth_{iteration}'
    gt_image_save_dir = f'{save_dir}/{subfolder_name}/gt_image_{iteration}'
    diff_image_save_dir = f'{save_dir}/{subfolder_name}/diff_image_{iteration}'
    print("image_save_dir: ", image_save_dir)

    for out_dir in (image_save_dir, depth_save_dir, gt_image_save_dir, diff_image_save_dir):
        os.makedirs(out_dir, exist_ok=True)

    for idx, gt_frame in gtimages.items():
        frame = Camera.init_from_tracking(gt_frame.squeeze()/255.0, None, None, traj[idx], idx, projection_matrix, K)
        gtimage = frame.original_image.cuda()

        rendering = render(frame, gaussians, background)
        image = torch.clamp(rendering["render"], 0.0, 1.0)
        depth = rendering["depth"].detach().squeeze().cpu().numpy()

        if gtdepthdir is not None:
            # TODO: add png scale
            gtdepth = cv2.imread(os.path.join(gtdepthdir, gtdepths[idx]), cv2.IMREAD_ANYDEPTH) / 6553.5 # 1000.
            gtdepth = cv2.resize(gtdepth, (depth.shape[-1], depth.shape[-2]), interpolation=cv2.INTER_NEAREST)
            # Drop rendered depth wherever the ground truth has no valid value.
            invalid = gtdepth <= 0
            depth[invalid] = 0

        # Channel-first float tensors -> channel-last uint8 arrays for OpenCV.
        # The colour conversion swaps the first and third channels before
        # cv2.imwrite (the swap is the same in either direction).
        pred = (image.detach().cpu().numpy().transpose((1, 2, 0)) * 255).astype(np.uint8)
        pred = cv2.cvtColor(pred, cv2.COLOR_BGR2RGB)

        gtimage_np = (gtimage.detach().cpu().numpy().transpose((1, 2, 0)) * 255).astype(np.uint8)
        gtimage_np = cv2.cvtColor(gtimage_np, cv2.COLOR_BGR2RGB)

        # Subtract in float32 so the uint8 difference cannot wrap around.
        diff_image = np.abs(gtimage_np.astype(np.float32) - pred.astype(np.float32)).astype(np.uint8)

        cv2.imwrite(f'{image_save_dir}/{idx:06d}.jpg', pred)
        cv2.imwrite(f'{gt_image_save_dir}/{idx:06d}.jpg', gtimage_np)
        cv2.imwrite(f'{diff_image_save_dir}/{idx:06d}.jpg', diff_image)
        # Depth is stored as 16-bit PNG using the same 6553.5 scale applied to
        # the ground-truth depth above.
        cv2.imwrite(f'{depth_save_dir}/{idx:06d}.png', np.clip(depth*6553.5, 0, 65535).astype(np.uint16))

        # TODO: depth L1 (with scale alignment) is not computed yet, so
        # l1_array stays empty.

        # PSNR only uses entries where the ground truth is non-zero; SSIM and
        # LPIPS are computed on the full image.
        mask = gtimage > 0
        psnr_score = psnr((image[mask]).unsqueeze(0), (gtimage[mask]).unsqueeze(0))
        ssim_score = ssim((image).unsqueeze(0), (gtimage).unsqueeze(0))
        lpips_score = cal_lpips((image).unsqueeze(0), (gtimage).unsqueeze(0))

        psnr_array.append(psnr_score.item())
        ssim_array.append(ssim_score.item())
        lpips_array.append(lpips_score.item())

    output = dict()
    output["mean_psnr"] = float(np.mean(psnr_array))
    output["mean_ssim"] = float(np.mean(ssim_array))
    output["mean_lpips"] = float(np.mean(lpips_array))
    output["mean_l1"] = float(np.mean(l1_array)) if l1_array else 0

    Log(f'mean psnr: {output["mean_psnr"]}, ssim: {output["mean_ssim"]}, lpips: {output["mean_lpips"]}, depth l1: {output["mean_l1"]}', tag="Eval")

    psnr_save_dir = os.path.join(save_dir, "psnr", str(iteration))
    os.makedirs(psnr_save_dir, exist_ok=True)

    # Use a context manager so the results file is flushed and closed
    # deterministically instead of leaking the handle.
    with open(os.path.join(psnr_save_dir, "final_result.json"), "w", encoding="utf-8") as result_file:
        json.dump(output, result_file, indent=4)
    return output

def load_traj_file(traj_file):
    """Load a whitespace-separated trajectory file.

    Every row is read as text. Column 0 is the timestamp; the remaining
    columns are passed to ``lietorch.SE3`` (presumably TUM-style
    ``tx ty tz qx qy qz qw`` -- confirm against the files being loaded).

    Args:
        traj_file: Path of the trajectory text file.

    Returns:
        timestamp: ``np.int64`` array holding column 0 multiplied by 1e9 and
            rounded to an integer.
        traj: Tensor of the matrices of the *inverted* poses, one per row,
            computed on the CPU in float64.
    """
    # ndmin=2 keeps a single-pose file two-dimensional; without it np.loadtxt
    # returns a 1-D array and the column slicing below raises IndexError.
    rows = np.loadtxt(traj_file, dtype=str, ndmin=2)

    # Go through Decimal so the textual timestamp is scaled exactly instead of
    # being rounded by a binary float first.
    timestamp = np.array(
        [int((Decimal(s) * Decimal('1e9')).to_integral_value()) for s in rows[:, 0]],
        dtype=np.int64,
    )

    pose_data = torch.from_numpy(rows[:, 1:].astype(np.float64)).to(device='cpu', dtype=torch.float64).contiguous()
    traj = lietorch.SE3(pose_data).inv().matrix().data

    return timestamp, traj

def process_results(results, name):
    """Print one method's metrics as LaTeX table rows.

    Prints ``name`` followed by four rows (PSNR, SSIM, LPIPS, depth L1). Each
    row lists the per-sequence value from every entry of ``results`` and ends
    with the mean over all entries. PSNR uses two decimals, the other metrics
    three. All rows except the last are terminated with a LaTeX line break.

    Args:
        results: List of dicts with the keys ``mean_psnr``, ``mean_ssim``,
            ``mean_lpips`` and ``mean_l1``. Nothing is printed for ``[]``.
        name: Heading printed above the rows.
    """
    if results == []:
        return

    # (result key, LaTeX macro, decimals, terminate row with "\\")
    row_specs = (
        ("mean_psnr", "\\psnr", 2, True),
        ("mean_ssim", "\\ssim", 3, True),
        ("mean_lpips", "\\lpips", 3, True),
        ("mean_l1", "\\l1", 3, False),
    )
    count = len(results)

    rows = []
    for key, macro, digits, terminate in row_specs:
        cells = [f"& {macro}"]
        total = 0.0
        for entry in results:
            value = entry[key]
            total += value
            cells.append(f"{value:.{digits}f}")
        cells.append(f"{total / count:.{digits}f}")
        line = " & ".join(cells)
        if terminate:
            line += " \\\\"
        rows.append(line)

    print(name)
    for line in rows:
        print(line)

def _load_optional_traj(traj_file):
    """Return the pose matrices from ``traj_file``, or None if loading fails.

    The failure reason is printed so a missing trajectory shows up in the log.
    """
    try:
        _, poses = load_traj_file(traj_file)
    except Exception as e:  # best-effort: a bad file only disables that evaluation
        print(f"Could not load trajectory {traj_file}: {e}")
        return None
    return poses


def _select_poses(poses, indices, label):
    """Return ``poses[indices]``, or None when ``poses`` is missing or too short."""
    if poses is None:
        return None
    try:
        return poses[indices]
    except (IndexError, RuntimeError) as e:
        print(f"[{label}] cannot select evaluation poses: {e}")
        return None


if __name__ == "__main__":
    # Batch evaluation: for every sequence, sample up to `num_eval_frames`
    # frames that no method used as a keyframe, then render those frames with
    # each method's Gaussians/poses (before and after final refinement) and
    # print the metrics as LaTeX rows.
    stride = 1
    basefolder_splat_slam = '/home/zihzhu/data/GS_VIO_SLAM/splat-slam_output/rpngar'
    basefolder_hislam2 = '/home/zihzhu/data/GS_VIO_SLAM/hislam2_output_for_finalBA/rpng_batch_eval_no_imu_stride1'
    seqs = sorted(glob('/home/zihzhu/data/Datasets/rpngar/*'))
    output_folder = "/home/zihzhu/data/output_rpng/finalBA_GSrefinement"

    num_eval_frames = 50
    # Entries of the dataset folder whose path contains one of these are skipped.
    skip_markers = ('TimeStamps', 'pgt', '.py', '.sh', 'stride')

    background = torch.tensor([0, 0, 0], dtype=torch.float32, device="cuda")

    ours_results = []
    splat_slam_results = []
    hislam2_results = []
    ours_finalBA_results = []
    splat_slam_finalBA_results = []
    hislam2_finalBA_results = []

    for seq in seqs:
        if any(marker in seq for marker in skip_markers):
            continue
        print("Processing: ", seq)
        name = os.path.basename(seq)

        save_dir = f'{output_folder}/{name}'
        folder_splat_slam = f'{basefolder_splat_slam}/{name}'
        folder_hislam2 = f'{basefolder_hislam2}/{name}'

        # --- Our trajectories -------------------------------------------------
        timestamp, traj = load_traj_file(f'{save_dir}/traj_full_beforeBA.txt')
        # Optional trajectories are (re)bound for every sequence so a failed
        # load can never fall back to the previous sequence's poses.
        traj_afterBA = _load_optional_traj(f'{save_dir}/traj_full_afterBA.txt')
        kf_timestamp, _ = load_traj_file(f'{save_dir}/traj_kf_beforeBA.txt')

        # --- Splat-SLAM trajectories -----------------------------------------
        # Column 0 of the Splat-SLAM files is used as a frame index here.
        splat_slam_kf_traj = np.loadtxt(f'{folder_splat_slam}/before_final_ba/traj/kf_traj_kf_est.tum', ndmin=2)
        splat_slam_kf_idx = splat_slam_kf_traj[:, 0].astype(int)

        _, splat_slam_traj = load_traj_file(f'{folder_splat_slam}/before_final_ba/traj/full_traj_full_est_aligned.tum')
        _, splat_slam_traj_afterBA = load_traj_file(f'{folder_splat_slam}/traj/full_traj_full_est_aligned.tum')

        splat_slam_full_traj = np.loadtxt(f'{folder_splat_slam}/before_final_ba/traj/full_traj_full_est_aligned.tum', ndmin=2)
        splat_slam_full_idx = splat_slam_full_traj[:, 0].astype(int)

        # --- HI-SLAM2 trajectories -------------------------------------------
        hislam2_traj = None
        hislam2_traj_afterBA = None
        if 'table_05' in seq:
            # No HI-SLAM2 trajectories are loaded for this sequence (presumably
            # no HI-SLAM2 output exists for it -- confirm); reuse our keyframes
            # for the frame filter and skip the HI-SLAM2 evaluations.
            hislam2_kf_timestamp = kf_timestamp
        else:
            hislam2_kf_timestamp, _ = load_traj_file(f'{folder_hislam2}/traj_kf_beforeBA.txt')
            hislam2_traj = _load_optional_traj(f'{folder_hislam2}/traj_full_beforeBA.txt')
            hislam2_traj_afterBA = _load_optional_traj(f'{folder_hislam2}/traj_full_afterBA.txt')

        # --- Pick evaluation frames ------------------------------------------
        # Sets / dict give O(1) lookups inside the permutation loop.
        kf_timestamp_set = set(kf_timestamp.tolist())
        hislam2_kf_timestamp_set = set(hislam2_kf_timestamp.tolist())
        splat_slam_kf_idx_set = set(splat_slam_kf_idx.tolist())
        # Frame index -> first row holding it in the Splat-SLAM full trajectory.
        splat_slam_row_of = {}
        for row, frame_idx in enumerate(splat_slam_full_idx.tolist()):
            splat_slam_row_of.setdefault(frame_idx, row)

        idxs_selected = []
        idxs_selected_splat_slam = []
        for candidate_idx in np.random.permutation(traj.shape[0]).tolist():
            candidate_ts = timestamp[candidate_idx].item()
            # Keep only frames that no method used as a keyframe and for which
            # Splat-SLAM has a pose.
            if (candidate_ts in kf_timestamp_set
                    or candidate_ts in hislam2_kf_timestamp_set
                    or candidate_idx in splat_slam_kf_idx_set
                    or candidate_idx not in splat_slam_row_of):
                continue
            idxs_selected.append(candidate_idx)
            idxs_selected_splat_slam.append(splat_slam_row_of[candidate_idx])
            if len(idxs_selected) == num_eval_frames:
                break

        if not idxs_selected:
            print(f"No evaluation frames available for {name}; skipping sequence")
            continue

        # These two cannot go out of range: the indices were derived from the
        # very trajectories they index.
        traj = traj[idxs_selected]
        splat_slam_traj = splat_slam_traj[idxs_selected_splat_slam]
        # The remaining trajectories come from other files and may be shorter;
        # each is handled on its own so one failure does not leave the others
        # unsubset (which would pair poses with the wrong frames).
        splat_slam_traj_afterBA = _select_poses(splat_slam_traj_afterBA, idxs_selected_splat_slam, "Splat-SLAM w/ refinement")
        traj_afterBA = _select_poses(traj_afterBA, idxs_selected, "Ours w/ refinement")
        hislam2_traj = _select_poses(hislam2_traj, idxs_selected, "HI-SLAM2")
        hislam2_traj_afterBA = _select_poses(hislam2_traj_afterBA, idxs_selected, "HI-SLAM2 w/ refinement")

        # --- Calibration and images ------------------------------------------
        if 'rpngar' in seq:
            intrinsics_file = 'calib/rpngar.txt'
        elif 'livo2' in seq:
            intrinsics_file = f'{seq}/intrinsics.txt'
        elif 'UTMM' in seq:
            intrinsics_file = f'{seq}/intrinsics_ours.txt'
        else:
            raise ValueError(f"Unknown dataset: {seq}")
        calib = np.loadtxt(intrinsics_file, delimiter=" ")

        imagedir = f'{seq}/rgb'
        image_files = sorted(glob(os.path.join(imagedir, '*.png')))
        print('NUM OF IMAGES: ', len(image_files))

        # More than four calibration values means distortion coefficients
        # follow fx, fy, cx, cy; the camera matrix is the same for every frame.
        has_distortion = len(calib) > 4
        if has_distortion:
            camera_matrix = np.array([[calib[0], 0, calib[2]], [0, calib[1], calib[3]], [0, 0, 1]])

        # Images are keyed by their position in the selection, matching the
        # row order of the subset trajectories above.
        images = {}
        for j, idx in enumerate(idxs_selected):
            image = cv2.imread(image_files[idx])
            H, W = image.shape[:2]
            if has_distortion:
                image = cv2.undistort(image, camera_matrix, calib[4:])
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            images[j] = torch.as_tensor(image).permute(2, 0, 1)

        K = list(calib[:4]) + [W, H]
        projection_matrix = getProjectionMatrix2(znear=0.01, zfar=100.0, fx=K[0], fy=K[1], cx=K[2], cy=K[3], W=W, H=H).transpose(0, 1).cuda()

        # --- Before final refinement -----------------------------------------
        gaussians = load_gaussians_from_ply(f'{save_dir}/3dgs_before_final.ply')
        output = eval_rendering(images, None, traj.cuda(), gaussians, save_dir, "render_before_final_est_pose", background, projection_matrix, K, iteration="after_opt")
        ours_results.append(output)

        gaussians = load_gaussians_from_ply(f'{folder_splat_slam}/point_cloud/iteration_before_refine_aligned/point_cloud.ply')
        output = eval_rendering(images, None, splat_slam_traj.cuda(), gaussians, folder_splat_slam, "render_before_final_est_pose", background, projection_matrix, K, iteration="after_opt")
        splat_slam_results.append(output)

        if hislam2_traj is not None:
            try:
                gaussians = load_gaussians_from_ply(f'{folder_hislam2}/3dgs_before_final.ply')
                output = eval_rendering(images, None, hislam2_traj.cuda(), gaussians, folder_hislam2, "render_before_final_est_pose", background, projection_matrix, K, iteration="after_opt")
                hislam2_results.append(output)
            except Exception as e:  # best-effort baseline: report and move on
                print(e)

        # --- After final refinement ------------------------------------------
        if traj_afterBA is not None:
            try:
                gaussians = load_gaussians_from_ply(f'{save_dir}/3dgs_final.ply')
                output = eval_rendering(images, None, traj_afterBA.cuda(), gaussians, save_dir, "render_after_final_est_pose", background, projection_matrix, K, iteration="after_opt")
                ours_finalBA_results.append(output)
            except Exception as e:  # report instead of silently dropping the result
                print(e)

        if splat_slam_traj_afterBA is not None:
            gaussians = load_gaussians_from_ply(f'{folder_splat_slam}/point_cloud/iteration_after_refine_aligned/point_cloud.ply')
            output = eval_rendering(images, None, splat_slam_traj_afterBA.cuda(), gaussians, folder_splat_slam, "render_after_final_est_pose", background, projection_matrix, K, iteration="after_opt")
            splat_slam_finalBA_results.append(output)

        if hislam2_traj_afterBA is not None:
            try:
                gaussians = load_gaussians_from_ply(f'{folder_hislam2}/3dgs_final.ply')
                output = eval_rendering(images, None, hislam2_traj_afterBA.cuda(), gaussians, folder_hislam2, "render_after_final_est_pose", background, projection_matrix, K, iteration="after_opt")
                hislam2_finalBA_results.append(output)
            except Exception as e:  # report instead of silently dropping the result
                print(e)

    process_results(ours_results, "Ours")
    process_results(splat_slam_results, "Splat-SLAM")
    process_results(hislam2_results, "HI-SLAM2")
    process_results(ours_finalBA_results, "Ours w/ refinement")
    process_results(splat_slam_finalBA_results, "Splat-SLAM w/ refinement")
    process_results(hislam2_finalBA_results, "HI-SLAM2 w/ refinement")




        # eval_rendering(gtimages, gtdepthdir, traj, self.gaussians,self.save_dir, self.background,
            # self.projection_matrix, self.K, kf_idx, iteration="after_opt")